from __future__ import annotations


from e2cnn.group import Group
from e2cnn.group import IrreducibleRepresentation, Representation
from e2cnn.group import utils

import numpy as np
import math

from typing import List, Tuple, Callable, Iterable


# Public API of this module: only the group class is exported by ``from ... import *``.
__all__ = ["CyclicGroup"]


# Module-level cache of already built groups, indexed by the group order ``N``.
# It is filled lazily by ``CyclicGroup._generator``.
_cached_group_instances = {}


class CyclicGroup(Group):
    
    def __init__(self, N: int):
        r"""
        Build an instance of the cyclic group :math:`C_N` which contains :math:`N` discrete planar rotations.
        
        The group elements are :math:`\{e, r, r^2, r^3, \dots, r^{N-1}\}`, with group law
        :math:`r^a \cdot r^b = r^{\ a + b \!\! \mod \!\! N \ }`.
        The cyclic group :math:`C_N` is isomorphic to the integers *modulo* ``N``.
        For this reason, elements are stored as the integers between :math:`0` and :math:`N-1`, where the :math:`k`-th
        element can also be interpreted as the discrete rotation by :math:`k\frac{2\pi}{N}`.
        
        Args:
            N (int): order of the group
            
        """
        
        assert (isinstance(N, int) and N > 0)
        
        super(CyclicGroup, self).__init__("C%d" % N, False, True)
        
        self.elements = list(range(N))

        self.elements_names = ['e'] + ['r%d' % i for i in range(1, N)]

        self.identity = 0
        
        self._build_representations()
        
    def inverse(self, element: int) -> int:
        r"""
        Return the inverse element :math:`r^{-j \mod N}` of the input element :math:`r^j`, specified by the input
        integer :math:`j` (``element``)
        
        Args:
            element (int): a group element :math:`r^j`

        Returns:
            its opposite :math:`r^{-j \mod N}`
            
        """
        return (-element) % self.order()

    def combine(self, e1: int, e2: int) -> int:
        r"""
        Return the composition of the two input elements.
        Given two integers :math:`a` and :math:`b` representing the elements :math:`r^a` and :math:`r^b`, the method
        returns the integer :math:`a + b \mod N` representing the element :math:`r^{a + b \mod N}`.
        

        Args:
            e1 (int): a group element :math:`r^a`
            e2 (int): another group element :math:`r^a`

        Returns:
            their composition :math:`r^{a+b \mod N}`
            
        """
        return (e1 + e2) % self.order()

    def equal(self, e1: int, e2: int) -> bool:
        r"""

        Check if the two input values corresponds to the same element.

        Args:
            e1 (int): an element
            e2 (int): another element

        Returns:
            whether they are the same element

        """
        return e1 == e2
    
    def is_element(self, element: int) -> bool:
        if isinstance(element, int):
            return 0 <= element < self.order()
        else:
            return False

    def testing_elements(self) -> Iterable[int]:
        r"""
        A finite number of group elements to use for testing.
        
        """
        return iter(self.elements)

    def __eq__(self, other):
        if not isinstance(other, CyclicGroup):
            return False
        else:
            return self.name == other.name and self.order() == other.order()

    def subgroup(self, id: int) -> Tuple[Group, Callable, Callable]:
        r"""
        Restrict the current group to the cyclic subgroup :math:`C_M`.
        If the current group is :math:`C_N`, it restricts to the subgroup generated by :math:`r^{(N/M)}`.
        Notice that :math:`M` has to divide the order :math:`N` of the current group.
        
        The method takes as input the integer :math:`M` identifying of the subgroup to build (the order of the subgroup)
        
        Args:
            id (int): the integer :math:`M` identifying of the subgroup

        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)
                
        """

        assert isinstance(id, int)

        order = id

        assert self.order() % order == 0, \
            "Error! The subgroups of a cyclic group have an order that divides the order of the supergroup." \
            " %d does not divide %d " % (order, self.order())

        if id not in self._subgroups:
    
            # Build the subgroup
            ratio = self.order()//order
            
            # take the elements of the group generated by "r^ratio"
            sg = CyclicGroup(order)
    
            parent_mapping = lambda e, ratio=ratio: e * ratio
            child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio)
            
            self._subgroups[id] = sg, parent_mapping, child_mapping
    
        return self._subgroups[id]

    def _restrict_irrep(self, irrep: str, id: int) -> Tuple[np.matrix, List[str]]:
        r"""
        
        Restrict the input irrep to the subgroup :math:`C_m` with order ``m``.
        If the current group is :math:`C_n`, it restricts to the subgroup generated by :math:`r^{(n/m)}`.
        Notice that :math:`m` has to divide the order :math:`n` of the current group.
        
        The method takes as input the integer :math:`m` identifying of the subgroup to build (the order of the subgroup)

        Args:
            irrep (str): the name/identifier of the irrep to restrict
            id (int): the integer ``m`` identifying the subgroup

        Returns:
            a pair containing the change of basis and the list of irreps of the subgroup which appear in the restricted irrep
            
        """
    
        irr = self.irreps[irrep]
    
        # Build the subgroup
        sg, _, _ = self.subgroup(id)
    
        order = id
    
        change_of_basis = None
        irreps = []
    
        f = irr.attributes["frequency"] % order
    
        if f > order/2:
            f = order - f
            change_of_basis = np.array([[1, 0], [0, -1]])
        else:
            change_of_basis = np.eye(irr.size)
    
        r = f"irrep_{f}"
    
        irreps.append(r)
        if sg.irreps[r].size < irr.size:
            irreps.append(r)
        
        return change_of_basis, irreps

    def _build_representations(self):
        r"""
        Build the irreps and the regular representation for this group
        
        """
        
        N = self.order()

        # Build all the Irreducible Representations
        for k in range(0, int(N // 2) + 1):
            self.irrep(k)
            
        # Build all Representations

        # add all the irreps to the set of representations already built for this group
        self.representations.update(**self.irreps)

        # build the regular representation
        self.representations['regular'] = self.regular_representation
        self.representations['regular'].supported_nonlinearities.add('vectorfield')

    def _build_quotient_representations(self):
        r"""
        Build all the quotient representations for this group

        """
        for n in range(2, int(math.ceil(math.sqrt(self.order())))):
            if self.order() % n == 0:
                self.quotient_representation(n)
    
    @property
    def trivial_representation(self) -> Representation:
        return self.representations['irrep_0']

    def irrep(self, k: int) -> IrreducibleRepresentation:
        r"""
        Build the irrep of frequency ``k`` of the current cyclic group.
        The frequency has to be a non-negative integer in :math:`\{0, \dots, \left \lfloor N/2 \right \rfloor \}`,
        where :math:`N` is the order of the group.
        
        Args:
            k (int): the frequency of the representation

        Returns:
            the corresponding irrep

        """
        assert 0 <= k <= self.order()//2

        name = f"irrep_{k}"
        
        if name not in self.irreps:
        
            n = self.order()
    
            base_angle = 2.0 * np.pi / n
            
            if k == 0:
                # Trivial representation
            
                irrep = lambda element, identity=np.eye(1): identity
                character = lambda e: 1
                supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              # character=character,
                                                              # trivial=True,
                                                              frequency=k)
            elif n % 2 == 0 and k == int(n/2):
                # 1 dimensional Irreducible representation (only for even order groups)
                irrep = lambda element, k=k, base_angle=base_angle: np.array([[np.cos(k * element * base_angle)]])
                supported_nonlinearities = ['norm', 'gated', 'concatenated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
            else:
                # 2 dimensional Irreducible Representations
                
                # build the rotation matrix with rotation frequency 'frequency'
                irrep = lambda element, k=k, base_angle=base_angle: utils.psi(element * base_angle, k=k)
            
                supported_nonlinearities = ['norm', 'gated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 2, 2,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
        return self.irreps[name]

    @staticmethod
    def _generator(N: int) -> 'CyclicGroup':
        r"""
        Return the instance of the cyclic group of order ``N`` stored in the module-level cache, building and
        storing it first if no group of this order has been requested before.

        Args:
            N (int): order of the group

        Returns:
            the cached instance of the group

        """
        instance = _cached_group_instances.get(N)
        if instance is None:
            instance = CyclicGroup(N)
            _cached_group_instances[N] = instance

        return instance

